#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Union

import torch.cuda
import torch.distributed as dist
from torch import Tensor

from colossalai.communication import *
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
                             ZeroRedundancyOptimizer_Level_3)
from colossalai.utils import get_current_device, switch_virtual_pipeline_parallel_rank
from ._base_schedule import BaseSchedule


def squeeze(x: Union[Tensor, tuple, list]):
    """Unwrap a tuple/list to its first element; pass anything else through."""
    return x[0] if isinstance(x, (tuple, list)) else x


class PipelineSchedule(BaseSchedule):
    """A helper schedule class for pipeline parallelism running environment.
    It uses non-interleaved 1F1B strategy. Other properties are similar as
    :class:`NonPipelineSchedule`.

    :param num_microbatches: The number of microbatches
    :param sync_data: If set to `True`, will sync data every batch over pipeline stages
    :type num_microbatches: int
    :type sync_data: bool
    """

    def __init__(self,
                 num_microbatches,
                 sync_data: bool = True):
        super().__init__()

        self.num_microbatches = num_microbatches
        self.sync_data = sync_data
        # Dtype of the tensors exchanged between stages; switched to
        # torch.half in pre_processing() when the engine's model is wrapped
        # in NaiveAMPModel.
        self.dtype = torch.float

    def _move_to_device(self, data):
        """Unwrap a one-element tuple/list, move the tensor to the current
        device and detach it so microbatch slices do not carry autograd
        history from the full batch."""
        if isinstance(data, (
                tuple,
                list,
        )):
            assert len(data) == 1, "Data tuple's length in pipeline should be 1"
            data = data[0]
        assert torch.is_tensor(data), "Data in pipeline should be tensor"
        data = data.to(get_current_device()).detach()
        return data

    def _sync_data(self):
        """Share the loaded batch between the first and last pipeline stages.

        Broadcasts run asynchronously inside the PIPELINE_PREV / PIPELINE_NEXT
        groups; presumably this hands the last stage (which evaluates the
        criterion) the same data/labels loaded on the first stage -- TODO
        confirm against how gpc constructs these groups.
        """
        reqs = []
        if gpc.is_first_rank(ParallelMode.PIPELINE):
            # First stage owns the batch: broadcast from this rank.
            src_rank = gpc.get_global_rank()
            reqs.append(dist.broadcast(
                tensor=self.batch_data,
                src=src_rank,
                group=gpc.get_group(ParallelMode.PIPELINE_PREV),
                async_op=True
            ))
            reqs.append(dist.broadcast(
                tensor=self.batch_label,
                src=src_rank,
                group=gpc.get_group(ParallelMode.PIPELINE_PREV),
                async_op=True
            ))
        if gpc.is_last_rank(ParallelMode.PIPELINE):
            # Last stage receives; its "next" rank in the pipeline ring is
            # presumably the first stage, i.e. the broadcast source above.
            src_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
            reqs.append(dist.broadcast(
                tensor=self.batch_data,
                src=src_rank,
                group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
                async_op=True
            ))
            reqs.append(dist.broadcast(
                tensor=self.batch_label,
                src=src_rank,
                group=gpc.get_group(ParallelMode.PIPELINE_NEXT),
                async_op=True
            ))
        # Block until every async broadcast completes before the batch is
        # sliced into microbatches.
        for req in reqs:
            req.wait()

    # Pipeline schedule just puts data in memory
    def load_batch(self, data_iter):
        """Fetch one (data, label) batch, move it to the current device and
        derive the microbatch size.

        :raises RuntimeError: if ``data_iter`` is None
        """
        if data_iter is None:
            raise RuntimeError('Dataloader is not defined.')
        self.batch_pos = 0
        data, label = next(data_iter)
        self.batch_data, self.batch_label = \
            self._move_to_device(data), self._move_to_device(label)
        batch_size = self.batch_data.shape[0]
        assert batch_size % self.num_microbatches == 0, \
            "Batch size should divided by the number of microbatches"
        self.microbatch_size = batch_size // self.num_microbatches
        if self.sync_data:
            self._sync_data()

    def _get_data_slice(self, tensor):
        # Current microbatch window along dim 0.
        return tensor[self.batch_pos: self.batch_pos + self.microbatch_size]

    def load_micro_batch(self):
        """Return the next (data, label) microbatch, each wrapped in a
        1-tuple, and advance the batch cursor."""
        data = self._get_data_slice(self.batch_data)
        label = self._get_data_slice(self.batch_label)
        self.batch_pos += self.microbatch_size
        return (data,), (label,)

    def pre_processing(self, engine):
        """Validate engine compatibility and pick the communication dtype.

        :raises TypeError: if the engine uses a ZeRO level 2/3 optimizer,
            which this schedule does not support
        """
        if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
            raise TypeError(
                "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
            )

        if isinstance(engine.model, NaiveAMPModel):
            self.dtype = torch.half

    def forward_step(self, engine, input_tensor, return_tensors, return_loss=True):
        """Forward step for passed-in model. If it is the first stage, the input tensor
        is obtained from data_iterator, otherwise the passed-in input_tensor is used.
        Returns output tensor. This is a helper function and can be ignored by users.

        :param engine: your engine object
        :type engine: colossalai.engine.Engine
        :param input_tensor: input tensor for this pipeline stage
        :type input_tensor: :class:`torch.Tensor`
        :param return_tensors: a list of tensors to return
        :type return_tensors: List[:class:`torch.Tensor`]
        :param return_loss: whether the last stage computes and returns the loss
        :type return_loss: bool

        :return: output or the loss value of the current pipeline stage
        :rtype: :class:`torch.Tensor`
        """

        if input_tensor is None:
            # First stage: read the microbatch from the locally loaded batch.
            input_tensor, label = self.load_micro_batch()
        input_tensor = squeeze(input_tensor)
        output_tensor = engine(input_tensor)
        output_tensor = squeeze(output_tensor)

        if gpc.is_last_rank(ParallelMode.PIPELINE):
            if return_loss:
                # NOTE(review): on a single-stage pipeline this rank is both
                # first and last, so load_micro_batch() would run twice per
                # forward step -- verify this schedule is never used with
                # pipeline world size 1.
                input_tensor, label = self.load_micro_batch()
                # Scale so the sum over microbatches equals the batch loss.
                loss_reduced = engine.criterion(output_tensor, *label) \
                    / self.num_microbatches

                return_tensors.append(
                    tuple((output_tensor, label[0], loss_reduced)))
                return loss_reduced
            else:
                return_tensors.append(output_tensor)
                return output_tensor

        else:
            return output_tensor

    def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
        """Backward step through the passed-in output tensor. If it is the last stage, the
        output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
        Returns the gradients with respect to the input tensor (None if first stage).
        This is a helper function and can be ignored by users.

        :param engine: your engine object
        :type engine: colossalai.engine.Engine
        :param input_tensor: input tensor for this pipeline stage
        :type input_tensor: :class:`torch.Tensor`
        :param output_tensor: output tensor for this pipeline stage
        :type output_tensor: :class:`torch.Tensor`
        :param output_tensor_grad: gradient of output tensor for this pipeline stage
        :type output_tensor_grad: :class:`torch.Tensor`

        :return: gradient of input tensor
        :rtype: :class:`torch.Tensor`
        """

        # Retain the grad on the input_tensor. Non-leaf grads are normally
        # freed during backward; retain_grad() keeps .grad populated so it
        # can be sent upstream.
        if input_tensor is not None:
            input_tensor.retain_grad()

        # Backward pass. The last stage backwards from the loss directly
        # (output_tensor_grad is None); other stages seed backward with the
        # gradient received from the next stage.
        if output_tensor_grad is None:
            engine.backward(output_tensor)
        else:
            engine.backward_by_grad(output_tensor, output_tensor_grad)

        # Collect the grad of the input_tensor.
        input_tensor_grad = None
        if input_tensor is not None:
            input_tensor_grad = input_tensor.grad

        return input_tensor_grad

    def forward_backward_step(self,
                              engine,
                              data_iter,
                              forward_only=False,
                              return_loss=True):
        """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
        Returns a tuple with losses if the last stage, an empty tuple otherwise.

        :param engine: your engine object
        :type engine: colossalai.engine.Engine
        :param data_iter: dataloader as the form of an iterator, obtained by calling iter(dataloader)
        :type data_iter: Iterable
        :param forward_only: whether run forward step only. Default is false. If true, no backward will be run.
        :type forward_only: bool
        :param return_loss: whether returns the loss value. Default is true.
        :type return_loss: bool

        :return: (output, label, loss)
        :rtype: Tuple[:class:`torch.Tensor`]
        """

        assert forward_only or return_loss, \
            'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'

        self.load_batch(data_iter)
        # Earlier stages run more warmup forwards so the pipeline fills up
        # before the steady 1F1B phase starts.
        num_warmup_microbatches = \
            (gpc.get_world_size(ParallelMode.PIPELINE) -
             gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
        num_warmup_microbatches = min(num_warmup_microbatches,
                                      self.num_microbatches)
        num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches

        # Input, output tensors only need to be saved when doing backward passes
        input_tensors = None
        output_tensors = None
        if not forward_only:
            input_tensors = []
            output_tensors = []
        return_tensors = []

        # Used for tensor meta information communication:
        # ft_shape / bt_shape are the shapes of forward/backward tensors
        # exchanged with neighbour stages; fs_checker tracks whether the
        # forward shape still needs to be sent downstream.
        ft_shape = None
        bt_shape = None
        fs_checker = True

        # Run warmup forward passes.
        for i in range(num_warmup_microbatches):
            if not gpc.is_first_rank(ParallelMode.PIPELINE):
                # Learn the incoming tensor's shape before receiving it.
                ft_shape = recv_tensor_meta(ft_shape)
            input_tensor = recv_forward(ft_shape, dtype=self.dtype)
            output_tensor = self.forward_step(
                engine, input_tensor, return_tensors,
                return_loss=return_loss
            )
            if not gpc.is_last_rank(ParallelMode.PIPELINE):
                bt_shape = output_tensor.shape
                fs_checker = send_tensor_meta(output_tensor, fs_checker)
            send_forward(output_tensor)

            if not forward_only:
                input_tensors.append(input_tensor)
                output_tensors.append(output_tensor)

        # Before running 1F1B, need to receive first forward tensor.
        # If all microbatches are run in warmup / cooldown phase, then no need to
        # receive this tensor here.
        if num_microbatches_remaining > 0:
            if not gpc.is_first_rank(ParallelMode.PIPELINE):
                ft_shape = recv_tensor_meta(ft_shape)
            input_tensor = recv_forward(ft_shape, dtype=self.dtype)

        # Run 1F1B in steady state.
        for i in range(num_microbatches_remaining):
            last_iteration = (i == (num_microbatches_remaining - 1))

            output_tensor = self.forward_step(
                engine, input_tensor, return_tensors,
                return_loss=return_loss
            )
            if forward_only:
                send_forward(output_tensor)

                if not last_iteration:
                    input_tensor = recv_forward(ft_shape, dtype=self.dtype)

            else:
                output_tensor_grad = send_forward_recv_backward(
                    output_tensor, bt_shape, dtype=self.dtype)

                # Add input_tensor and output_tensor to end of list.
                input_tensors.append(input_tensor)
                output_tensors.append(output_tensor)

                # Pop input_tensor and output_tensor from the start of the list for
                # the backward pass.
                input_tensor = input_tensors.pop(0)
                output_tensor = output_tensors.pop(0)

                input_tensor_grad = self.backward_step(
                    engine,
                    input_tensor, output_tensor,
                    output_tensor_grad
                )

                if last_iteration:
                    input_tensor = None
                    send_backward(input_tensor_grad)
                else:
                    input_tensor = send_backward_recv_forward(
                        input_tensor_grad, ft_shape, dtype=self.dtype)

        # Run cooldown backward passes.
        if not forward_only:
            for i in range(num_warmup_microbatches):
                input_tensor = input_tensors.pop(0)
                output_tensor = output_tensors.pop(0)

                output_tensor_grad = recv_backward(bt_shape, dtype=self.dtype)

                input_tensor_grad = self.backward_step(
                    engine,
                    input_tensor, output_tensor,
                    output_tensor_grad
                )

                send_backward(input_tensor_grad)

        if len(return_tensors) > 0:
            if return_loss:
                # Last stage accumulated (output, label, loss) triples:
                # regroup column-wise, concatenate along batch dim and sum
                # the per-microbatch losses.
                output, label, loss = tuple(map(list, zip(*return_tensors)))
                return (torch.cat(output, dim=0),
                        torch.cat(label, dim=0),
                        sum(loss))
            else:
                return tuple((torch.cat(return_tensors, dim=0), None, None))
        else:
            # Non-last stages have nothing to report.
            return tuple((None, None, None))


class InterleavedPipelineSchedule(PipelineSchedule):
    """Pipeline schedule using the interleaved 1F1B strategy.

    Each pipeline rank holds ``num_model_chunks`` model chunks (virtual
    stages), following the Megatron-LM interleaved schedule: ``engine.model``
    is expected to be a list of chunks, and microbatches are interleaved
    across chunks to shrink the pipeline bubble.

    :param num_microbatches: number of microbatches; must be an integer
        multiple of the pipeline parallel world size
    :type num_microbatches: int
    :param num_model_chunks: number of model chunks held by each pipeline rank
    :type num_model_chunks: int
    :param sync_data: if set to ``True``, will sync data every batch over pipeline stages
    :type sync_data: bool
    """

    def __init__(self, num_microbatches, num_model_chunks, sync_data: bool = True):
        assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
            'num_microbatches must be an integer multiple of pipeline parallel world size'
        super().__init__(num_microbatches, sync_data=sync_data)
        # Register the number of virtual stages and start at virtual rank 0.
        gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
        gpc.set_virtual_pipeline_parallel_rank(0)

    def pre_processing(self, engine):
        """Validate engine compatibility and pick the communication dtype.

        ``engine.model`` is a list of model chunks here; AMP is detected from
        the first chunk.

        :raises TypeError: if the engine uses a ZeRO level 2/3 optimizer,
            which this schedule does not support
        """
        if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
            raise TypeError(
                "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
            )

        if isinstance(engine.model[0], NaiveAMPModel):
            self.dtype = torch.half

    def forward_step(self, engine, model, input_tensor, return_tensors, return_loss=True):
        """Forward step for passed-in model. If it is the first stage, the input tensor
        is obtained from data_iterator, otherwise the passed-in input_tensor is used.
        Returns output tensor. This is a helper function and can be ignored by users.

        :param engine: your engine object
        :param model: the model chunk to run for the current virtual stage
        :param input_tensor: input tensor for this pipeline stage
        :param return_tensors: list that accumulates the last stage's outputs
        :param return_loss: whether the last stage computes and returns the loss

        :return: output or the loss value of the current pipeline stage
        """

        if input_tensor is None:
            # First (virtual) stage: read the microbatch locally.
            input_tensor, label = self.load_micro_batch()
        input_tensor = squeeze(input_tensor)
        output_tensor = model(input_tensor)
        output_tensor = squeeze(output_tensor)

        if gpc.is_pipeline_last_stage():
            if return_loss:
                input_tensor, label = self.load_micro_batch()
                # Scale so the sum over microbatches equals the batch loss.
                loss_reduced = engine.criterion(output_tensor, *label) / self.num_microbatches
                return_tensors.append(
                    tuple((output_tensor, label[0], loss_reduced)))
                return loss_reduced
            else:
                return_tensors.append(output_tensor)
                return output_tensor
        else:
            return output_tensor

    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True):
        """Run interleaved 1F1B schedule (model split into model chunks), with
        communication between pipeline stages as needed.

        Returns dictionary with losses if the last stage, empty dict otherwise."""
        assert forward_only or return_loss, \
            'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
        self.load_batch(data_iter)
        model = engine.model
        # Per-chunk bookkeeping for tensors saved between forward and backward.
        input_tensors = [[] for _ in range(len(model))]
        output_tensors = [[] for _ in range(len(model))]
        return_tensors = []
        if not forward_only:
            output_tensor_grads = [[] for _ in range(len(model))]

        # Used for tensor meta information communication
        input_tensor_shapes = [None for _ in range(len(model))]
        output_tensor_shapes = [None for _ in range(len(model))]
        send_tensor_shape_flags = [True for _ in range(len(model))]

        pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
        pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

        # Compute number of warmup and remaining microbatches.
        num_model_chunks = len(model)
        num_microbatches = self.num_microbatches * num_model_chunks
        all_warmup_microbatches = False
        if forward_only:
            num_warmup_microbatches = num_microbatches
        else:
            # Run all forward passes and then all backward passes if number of
            # microbatches is just the number of pipeline stages.
            # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
            # all workers, followed by more microbatches after depending on
            # stage ID (more forward passes for earlier stages, later stages can
            # immediately start with 1F1B).
            if self.num_microbatches == pipeline_parallel_size:
                num_warmup_microbatches = num_microbatches
                all_warmup_microbatches = True
            else:
                num_warmup_microbatches = \
                    (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
                num_warmup_microbatches += (
                    num_model_chunks - 1) * pipeline_parallel_size
                num_warmup_microbatches = min(num_warmup_microbatches,
                                              num_microbatches)
        num_microbatches_remaining = \
            num_microbatches - num_warmup_microbatches

        def get_model_chunk_id(microbatch_id, forward):
            """Helper method to get the model chunk ID given the iteration number."""
            microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
            model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
            if not forward:
                # Backward passes traverse the chunks in reverse order.
                model_chunk_id = (num_model_chunks - model_chunk_id - 1)
            return model_chunk_id

        def forward_step_helper(microbatch_id):
            """Helper method to run forward step with model split into chunks
            (run set_virtual_pipeline_model_parallel_rank() before calling
            forward_step())."""
            model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
            gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)

            # forward step
            if gpc.is_pipeline_first_stage():
                # First virtual stage reads from the dataloader, so pad the
                # saved-input list with None to keep it aligned with outputs.
                if len(input_tensors[model_chunk_id]) == \
                        len(output_tensors[model_chunk_id]):
                    input_tensors[model_chunk_id].append(None)
            input_tensor = input_tensors[model_chunk_id][-1]
            output_tensor = self.forward_step(
                engine, model[model_chunk_id], input_tensor, return_tensors, return_loss=return_loss)
            output_tensors[model_chunk_id].append(output_tensor)

            # if forward-only, no need to save tensors for a backward pass
            if forward_only:
                input_tensors[model_chunk_id].pop()
                output_tensors[model_chunk_id].pop()

            return output_tensor

        def backward_step_helper(microbatch_id):
            """Helper method to run backward step with model split into chunks
            (run set_virtual_pipeline_model_parallel_rank() before calling
            backward_step())."""
            model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
            gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)

            if gpc.is_pipeline_last_stage():
                # Last virtual stage backwards from the loss: no incoming grad.
                if len(output_tensor_grads[model_chunk_id]) == 0:
                    output_tensor_grads[model_chunk_id].append(None)
            input_tensor = input_tensors[model_chunk_id].pop(0)
            output_tensor = output_tensors[model_chunk_id].pop(0)
            output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
            input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad)

            return input_tensor_grad

        # Run warmup forward passes.
        gpc.set_virtual_pipeline_parallel_rank(0)
        if not gpc.is_pipeline_first_stage():
            input_tensor_shapes[0] = recv_tensor_meta(input_tensor_shapes[0])
        input_tensors[0].append(recv_forward(input_tensor_shapes[0], dtype=self.dtype))

        for k in range(num_warmup_microbatches):
            model_chunk_id = get_model_chunk_id(k, forward=True)
            output_tensor = forward_step_helper(k)
            if not gpc.is_pipeline_last_stage():
                output_tensor_shapes[model_chunk_id] = output_tensor.shape
                send_tensor_shape_flags[model_chunk_id] = send_tensor_meta(
                    output_tensor, send_tensor_shape_flags[model_chunk_id])
            # Determine if tensor should be received from previous stage.
            next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
            recv_prev = True
            if gpc.is_pipeline_first_stage(ignore_virtual=True):
                if next_forward_model_chunk_id == 0:
                    recv_prev = False
            if k == (num_microbatches - 1):
                recv_prev = False

            # Don't send tensor downstream if on last stage.
            if gpc.is_pipeline_last_stage():
                output_tensor = None

            with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
                if not gpc.is_pipeline_first_stage():
                    input_tensor_shapes[next_forward_model_chunk_id] = recv_tensor_meta(
                        input_tensor_shapes[next_forward_model_chunk_id])
            # Send and receive tensors as appropriate (send tensors computed
            # in this iteration; receive tensors for next iteration).
            input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None
            if k == (num_warmup_microbatches - 1) and not forward_only and \
                    not all_warmup_microbatches:
                # Transitioning into steady 1F1B: also receive the first
                # backward gradient together with the next forward input.
                input_tensor_grad = None
                recv_next = True
                if gpc.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False
                output_shape = output_tensor_shapes[num_model_chunks-1] if recv_next else None
                input_tensor, output_tensor_grad = \
                    send_forward_backward_recv_forward_backward(
                        output_tensor, input_tensor_grad,
                        input_shape,
                        output_shape,
                        recv_prev=recv_prev, recv_next=recv_next,
                        dtype=self.dtype)
                output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
            else:
                input_tensor = \
                    send_forward_recv_forward(
                        output_tensor,
                        input_shape,
                        recv_prev=recv_prev,
                        dtype=self.dtype)
            input_tensors[next_forward_model_chunk_id].append(input_tensor)

        # Run 1F1B in steady state.
        for k in range(num_microbatches_remaining):
            # Forward pass.
            forward_k = k + num_warmup_microbatches
            output_tensor = forward_step_helper(forward_k)

            # Backward pass.
            backward_k = k
            input_tensor_grad = backward_step_helper(backward_k)

            # Send output_tensor and input_tensor_grad, receive input_tensor
            # and output_tensor_grad.

            # Determine if current stage has anything to send in either direction,
            # otherwise set tensor to None.
            forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
            gpc.set_virtual_pipeline_parallel_rank(forward_model_chunk_id)
            if gpc.is_pipeline_last_stage():
                output_tensor = None

            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
            gpc.set_virtual_pipeline_parallel_rank(backward_model_chunk_id)
            if gpc.is_pipeline_first_stage():
                input_tensor_grad = None

            # Determine if peers are sending, and where in data structure to put
            # received tensors.
            recv_prev = True
            if gpc.is_pipeline_first_stage(ignore_virtual=True):
                # First stage is ahead of last stage by (pipeline_parallel_size - 1).
                next_forward_model_chunk_id = get_model_chunk_id(
                    forward_k - (pipeline_parallel_size - 1), forward=True)
                if next_forward_model_chunk_id == (num_model_chunks - 1):
                    recv_prev = False
                next_forward_model_chunk_id += 1
            else:
                next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                                 forward=True)

            recv_next = True
            if gpc.is_pipeline_last_stage(ignore_virtual=True):
                # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
                next_backward_model_chunk_id = get_model_chunk_id(
                    backward_k - (pipeline_parallel_size - 1), forward=False)
                if next_backward_model_chunk_id == 0:
                    recv_next = False
                next_backward_model_chunk_id -= 1
            else:
                next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                                  forward=False)

            # If last iteration, don't receive; we already received one extra
            # before the start of the for loop.
            if k == (num_microbatches_remaining - 1):
                recv_prev = False

            input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None
            output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
            # Communicate tensors.
            input_tensor, output_tensor_grad = \
                send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    input_shape,
                    output_shape,
                    recv_prev=recv_prev, recv_next=recv_next,
                    dtype=self.dtype)

            # Put input_tensor and output_tensor_grad in data structures in the
            # right location.
            if recv_prev:
                input_tensors[next_forward_model_chunk_id].append(input_tensor)
            if recv_next:
                output_tensor_grads[next_backward_model_chunk_id].append(
                    output_tensor_grad)

        # Run cooldown backward passes (flush out pipeline).
        if not forward_only:
            if all_warmup_microbatches:
                # Fix: pass the communication dtype explicitly, consistent with
                # every other recv/send call in this schedule; without it an AMP
                # (half-precision) run would receive into a float32 buffer here.
                output_tensor_grads[num_model_chunks-1].append(
                    recv_backward(output_tensor_shapes[num_model_chunks-1],
                                  dtype=self.dtype))
            for k in range(num_microbatches_remaining, num_microbatches):
                input_tensor_grad = backward_step_helper(k)
                next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
                recv_next = True
                if gpc.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_model_chunk_id == (num_model_chunks - 1):
                        recv_next = False
                if k == (num_microbatches - 1):
                    recv_next = False
                output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
                output_tensor_grads[next_backward_model_chunk_id].append(
                    send_backward_recv_backward(
                        input_tensor_grad,
                        output_shape,
                        recv_next=recv_next,
                        dtype=self.dtype))

        if len(return_tensors) > 0:
            if return_loss:
                # Last stage accumulated (output, label, loss) triples:
                # regroup column-wise, concatenate along batch dim and sum
                # the per-microbatch losses.
                output, label, loss = tuple(map(list, zip(*return_tensors)))
                return (torch.cat(output, dim=0),
                        torch.cat(label, dim=0),
                        sum(loss))
            else:
                return tuple((torch.cat(return_tensors, dim=0), None, None))
        else:
            # Non-last stages have nothing to report.
            return tuple((None, None, None))
